#script used to generate cluster plots - Figure 4 supplement 3
import MDAnalysis as mda
import MDAnalysis.analysis.align as align
import numpy as np
from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
import matplotlib.pyplot as plt

# Simulation input: one topology file plus the trajectory segments, all handed
# to a single Universe (replace with your own file names)
topology_file = '/beagle3/wtang/ACE_MD/ACE/ACE_protein_Zn.pdb'
trajectory_files = (
    '/beagle3/wtang/ACE_MD/ACE/run3/ACErun3_Feb012024_unwrapped.dcd',
    '/beagle3/wtang/ACE_MD/ACE/run7/ACErun7_Feb012024_unwrapped.dcd',
    '/beagle3/wtang/ACE_MD/ACE/run8/ACErun8_Feb012024.dcd',
    '/beagle3/wtang/ACE_MD/ACE/run9/ACErun9_Feb012024_unwrapped.dcd',
)
u = mda.Universe(topology_file, *trajectory_files)
print('Simulation files loaded')

# Reference PDB structures to compare against the simulation. Only the paths
# are listed here; each file is opened and fitted to the simulation later on.
pdb_files = [
    '/beagle3/wtang/ACE_MD/ACE/Analysis/PCA/365_aligned_apo.pdb',
    '/beagle3/wtang/ACE_MD/ACE/Analysis/PCA/315_aligned_apo.pdb',
    '/beagle3/wtang/ACE_MD/ACE/Analysis/PCA/305_aligned_apo.pdb',
    '/beagle3/wtang/ACE_MD/ACE/Analysis/PCA/299_aligned_apo.pdb',
]
print('PDBs loaded')

# Atom subset used for the analysis: the protein alpha carbons
atoms = u.select_atoms("protein and name CA")
print('Atoms selected')

# Align simulation frames to the first frame
def align_to_reference(u, reference, atom_selection):
    """Fit every frame of ``u`` onto ``reference`` and collect the coordinates.

    Parameters
    ----------
    u : MDAnalysis.Universe
        Universe whose trajectory is aligned. It is modified: the alignment
        is run with ``in_memory=True``, so the fitted coordinates stay in
        ``u`` after this function returns.
    reference : MDAnalysis.Universe
        Universe providing the reference structure. ``AlignTraj`` is called
        without ``ref_frame``, so its default reference frame is used.
    atom_selection : str
        MDAnalysis selection string; used both for the fit and for choosing
        which atoms' coordinates are returned.

    Returns
    -------
    numpy.ndarray
        One entry per trajectory frame of ``u``, each holding the positions
        of the selected atoms after alignment.
    """
    mobile_atoms = u.select_atoms(atom_selection)

    # The fit is performed for its side effect on ``u``; the AlignTraj object
    # itself is not needed afterwards.
    align.AlignTraj(u, reference, select=atom_selection, in_memory=True).run()

    # Stepping through the trajectory updates ``mobile_atoms.positions`` to
    # the current frame, so each iteration captures one frame's coordinates.
    return np.array([mobile_atoms.positions for _ in u.trajectory])

# Aligned alpha-carbon coordinates for every simulation frame (the trajectory
# is fitted onto its own universe)
ca_selection = "protein and name CA"
sim_coordinates = align_to_reference(u, u, ca_selection)
print('Simulation coordinates extracted and aligned')

# Fit each comparison PDB onto the simulation universe and record the
# resulting alpha-carbon positions
pdb_positions = []
for pdb_path in pdb_files:
    pdb_universe = mda.Universe(pdb_path)
    pdb_ca_atoms = pdb_universe.select_atoms(ca_selection)
    pdb_fit = align.AlignTraj(pdb_universe, u, select=ca_selection, in_memory=True)
    pdb_fit.run()
    pdb_positions.append(pdb_ca_atoms.positions)

pdb_coordinates = np.array(pdb_positions)
print('PDB coordinates extracted and aligned')

# Stack simulation frames first, PDB structures last; later steps rely on this
# ordering to tell the two groups apart by index
all_coordinates = np.concatenate((sim_coordinates, pdb_coordinates))
print('Coordinates combined')

# Calculate RMSD matrix
def rmsd(a, b):
    """Return the square root of the summed squared differences of ``a`` and ``b``.

    Both arguments are coordinate arrays of the same shape (or broadcastable
    to one another).

    NOTE(review): the sum is not divided by the number of atoms, so this is
    the Euclidean norm of ``a - b`` rather than a per-atom root-mean-square
    value. Confirm this is intended before comparing the numbers with RMSDs
    reported by other tools.
    """
    difference = a - b
    squared_total = np.sum(np.square(difference))
    return np.sqrt(squared_total)

# Pairwise distances between all structures.
# Each structure is flattened to one row so the distances can be computed in a
# single vectorised pdist call instead of two separate O(n^2) Python loops.
# The built-in 'euclidean' metric yields the same value as this file's rmsd()
# helper: the square root of the summed squared coordinate differences.
num_structures = all_coordinates.shape[0]
flat_coordinates = all_coordinates.reshape(num_structures, -1)
condensed_distances = pdist(flat_coordinates, metric='euclidean')

# Square, symmetric form of the same distances (zero diagonal)
rmsd_matrix = squareform(condensed_distances)
distance_matrix = rmsd_matrix
print('RMSD matrix calculated')

# Perform hierarchical clustering.
# linkage() must be given the *condensed* distance vector: a 2-D array is
# interpreted as a set of observation vectors, so passing the square matrix
# would cluster the rows of the distance matrix rather than the structures.
# Ward linkage on precomputed distances is only well defined for Euclidean
# distances, which is what pdist produced above.
linkage_matrix = linkage(condensed_distances, method='ward')

# no_plot=True: only the dendrogram data is needed here. Drawing at this point
# would go into an implicit figure that is never saved.
dendro = dendrogram(
    linkage_matrix,
    labels=np.concatenate([['Sim'] * len(sim_coordinates), ['PDB'] * len(pdb_coordinates)]),
    no_plot=True,
)
print('Hierarchical clustering completed')

# Define a function to cluster the data
def cluster_data(linkage_matrix, threshold):
    """Cut a hierarchical clustering tree into flat clusters.

    Parameters
    ----------
    linkage_matrix : array-like
        Linkage matrix as produced by ``scipy.cluster.hierarchy.linkage``.
    threshold : float
        Cut-off passed to ``fcluster`` with the ``'distance'`` criterion.

    Returns
    -------
    numpy.ndarray
        Flat cluster label for each original observation.
    """
    return fcluster(linkage_matrix, t=threshold, criterion='distance')
print('Cluster function defined')

# Cut the tree at a fixed linkage distance (tune this value as needed)
threshold = 1.0
clusters = cluster_data(linkage_matrix, threshold)
print('Data clustered')

# Draw the dendrogram with the cut height marked as a dashed red line, then
# write the figure to disk
leaf_labels = np.concatenate([['Sim'] * len(sim_coordinates), ['PDB'] * len(pdb_coordinates)])
fig, ax = plt.subplots(figsize=(10, 6))
dendrogram(linkage_matrix, labels=leaf_labels, ax=ax)
ax.axhline(y=threshold, color='r', linestyle='--')
ax.set_title('Hierarchical Clustering Dendrogram')
ax.set_xlabel('Index')
ax.set_ylabel('Distance')
fig.savefig('Apo_cluster.png', dpi=300, bbox_inches='tight')
print('Data plotted')

# Print cluster assignments.
# Indices below the number of simulation frames are trajectory frames; the
# remaining indices are the PDB structures appended after them.
num_sim_frames = len(sim_coordinates)
for structure_idx, cluster_label in enumerate(clusters):
    label = 'Sim' if structure_idx < num_sim_frames else 'PDB'
    print(f'Structure {structure_idx} ({label}): Cluster {cluster_label}')

# Optionally, save representative structures from each cluster
# (Here, we save the first structure from each cluster as representative)
unique_clusters = np.unique(clusters)
for cluster_id in unique_clusters:
    member_indices = np.flatnonzero(clusters == cluster_id)
    representative_idx = int(member_indices[0])

    # Members are in ascending order, so if the first one is not a simulation
    # frame the cluster holds PDB structures only. Those indices do not exist
    # in the trajectory, and indexing it with them would raise IndexError.
    if representative_idx >= num_sim_frames:
        pdb_members = [pdb_files[idx - num_sim_frames] for idx in member_indices]
        print(f'Cluster {cluster_id} contains only PDB structures {pdb_members}; no simulation frame saved.')
        continue

    u.trajectory[representative_idx]  # Set trajectory to the representative frame
    u.atoms.write(f'cluster_{cluster_id}_representative.pdb')
print('Representative structures saved.')

print('Cluster analysis completed.')
